
Conversation

@fynnsu
Collaborator

@fynnsu fynnsu commented Oct 3, 2025

This PR introduces Eagle3 model training into the speculators repo. The implementation is specific to Eagle3 but is designed in a way that enables future generalization to other speculative decoding algorithms.

Components

Eagle3 Training Components

Example training script (scripts/train_llama3_8b_drafter.py → scripts/train.py)

Shows how to set up and run training. Currently specific to the meta-llama/Llama-3.1-8B-Instruct model, but it doesn't require many changes to run with a different model. You just need to update:

VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size

Update: I've generalized the training script. It now has a required CLI arg --verifier_name_or_path and supports arbitrary verifier models. Note: this uses LlamaConfig.from_pretrained(args.verifier_name_or_path) under the hood, which does work for non-Llama models (e.g. a Qwen model) but prints a warning and may not work for every type of verifier.

You will also need to pass in a dataset and t2d / d2t tensors that correspond to the verifier you are using.
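As a point of reference, here is a minimal sketch of pulling the verifier dimensions from its config, mirroring the LlamaConfig.from_pretrained usage mentioned above (variable names here are illustrative):

```python
from transformers import LlamaConfig

# Derive the verifier's dimensions from its config rather than hardcoding them.
verifier_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
hidden_size = verifier_config.hidden_size  # 4096 for Llama-3.1-8B-Instruct
vocab_size = verifier_config.vocab_size    # 128256 for Llama-3.1-8B-Instruct
```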

Flex Attention

Files:

  • src/speculators/train/eagle3/attention.py
  • tests/unit/train/test_eagle3_attention.py

The training code uses FlexAttention, which provides substantial speedups and memory savings over full dense attention operations.

Functions:

  • create_combined_mask_mod(lengths, total_seq_len): creates the mask function used by flex attention (a generic sketch of this masking pattern follows this list).
  • extend_mask_for_draft_tokens(block_mask): helper to extend the block mask without needing to recompute the mask value for each new square.
  • block_mask_to_dense_attention_mask: only used for debugging purposes.
  • flex_attention_forward: lightweight wrapper around the flex attention call.
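For context, the general document-masking pattern these functions build on looks roughly like the sketch below: a mask_mod that only allows causal attention within each packed sequence. This is an illustration of the pattern (the names are mine), not the exact implementation in attention.py.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def packed_causal_mask_mod(lengths: list[int]):
    # Map each position in the packed batch back to the sequence it came from.
    seq_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only within the same packed sequence, and only causally.
        return (seq_ids[q_idx] == seq_ids[kv_idx]) & (q_idx >= kv_idx)

    return mask_mod

lengths = [5, 3, 8]
total_seq_len = sum(lengths)
block_mask = create_block_mask(
    packed_causal_mask_mod(lengths), B=None, H=None,
    Q_LEN=total_seq_len, KV_LEN=total_seq_len, device="cpu",
)
```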

Data processing

Eagle3 data flow files:

  • src/speculators/train/data.py

Data is currently expected in the format of one file per data sample. We load these samples and perform a shift to align input_ids, hidden_states, loss_mask, and verifier_last_hidden_state correctly. We also automatically collate these samples into batches. Rather than padding and wasting compute on padded tokens, we concatenate the sequences along the sequence dimension, keeping track of the boundaries between sequences and setting the attention mask accordingly.
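A simplified sketch of this padding-free collation (field and function names here are illustrative, not the exact collate function in data.py):

```python
import torch

def collate_packed(samples: list[dict]) -> dict:
    # Concatenate along the sequence dimension instead of padding every sample to a common length.
    lengths = [s["input_ids"].shape[0] for s in samples]
    return {
        "input_ids": torch.cat([s["input_ids"] for s in samples], dim=0),
        "hidden_states": torch.cat([s["hidden_states"] for s in samples], dim=0),
        "loss_mask": torch.cat([s["loss_mask"] for s in samples], dim=0),
        "lengths": torch.tensor(lengths),  # sequence boundaries, used to build the attention mask
    }
```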

Batch sampling

Files:

  • src/speculators/train/distributed_batch_sampler.py
  • src/speculators/train/data.py

Due to hardware limitations, we set a maximum sequence length for each batch. We would like each batch of data to be close in size to this max length, so that each batch has a similar number of tokens. We achieve this through the MultipackDistributedBatchSamplerV2, taken from prior work I did on instructlab/training. This class produces indices of files that, when batched together, come close to reaching the max length without exceeding it. It also does this in a distributed-aware manner so that there is no overlap in the data each rank sees.
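To illustrate the objective only (the real MultipackDistributedBatchSamplerV2 uses a more sophisticated, distributed-aware packing algorithm), a naive greedy packer might look like this:

```python
def greedy_pack(lengths: list[int], max_tokens: int) -> list[list[int]]:
    # Group sample indices into batches whose total token count stays under max_tokens.
    batches: list[list[int]] = []
    current: list[int] = []
    current_len = 0
    for idx, length in enumerate(lengths):
        if current and current_len + length > max_tokens:
            batches.append(current)
            current, current_len = [], 0
        current.append(idx)
        current_len += length
    if current:
        batches.append(current)
    return batches
```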

To run the packing algorithm, we need to know the length of each sample in the dataset. Unfortunately, this would require opening every file in the dataset, which is expensive, so instead we approximate the lengths (_compute_approx_lengths in data.py) using the length of the first sample and the relative file sizes of the samples.
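The approximation amounts to something like the following sketch (the real _compute_approx_lengths may differ in details, and the per-file dict key is an assumption):

```python
import os
import torch

def approx_lengths(paths: list[str]) -> list[int]:
    # Load only the first sample to get a reference length, then scale by relative file size.
    first_len = torch.load(paths[0], map_location="cpu")["input_ids"].shape[0]
    first_size = os.path.getsize(paths[0])
    return [round(first_len * os.path.getsize(p) / first_size) for p in paths]
```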

Eagle3DraftModel

Files:

  • src/speculators/train/eagle3/core.py

The draft model itself. Sets up and loads verifier components, as well as the draft layers / weights. Contains the model forward() pass which:

  • sets up the block mask for the batch
  • computes the target logits using the attached verifier_lm_head. Note: this is computed here for data-storage efficiency, as otherwise we would need to save the full logits ([seq_len, vocab_size]) to disk instead of the last-layer hidden states ([seq_len, hidden_size]). The verifier vocab_size is often > 100k, whereas hidden_size is typically around 4-8k.
  • For each TTT step (sketched after this list):
    • embeds tokens
    • concatenates with hidden_states
    • applies decoder layers
    • computes logits
    • computes loss and step accuracy
    • prepares next-step tokens
    • updates the block mask
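A highly schematic, toy-sized sketch of this loop is shown below. It collapses the decoder stack into a single linear layer and omits the block-mask update and accuracy tracking, so it only illustrates the shape of the computation, not the actual Eagle3DraftModel:

```python
import torch
from torch import nn

hidden, vocab, seq_len, ttt_steps = 8, 32, 6, 3
embed_tokens = nn.Embedding(vocab, hidden)
first_layer = nn.Linear(2 * hidden, hidden)       # stand-in for the widened (2x hidden) first decoder layer
lm_head = nn.Linear(hidden, vocab)
loss_fn = nn.CrossEntropyLoss()

input_ids = torch.randint(0, vocab, (seq_len,))
verifier_hidden = torch.randn(seq_len, hidden)    # stored verifier hidden states
target_ids = torch.randint(0, vocab, (seq_len,))  # stand-in for the verifier's target tokens

total_loss = torch.tensor(0.0)
for step in range(ttt_steps):
    inputs_embeds = embed_tokens(input_ids)                               # embed tokens
    h = first_layer(torch.cat([inputs_embeds, verifier_hidden], dim=-1))  # concatenate with hidden states
    logits = lm_head(h)                                                   # compute logits
    total_loss = total_loss + loss_fn(logits, target_ids)                 # loss for this step
    input_ids = logits.argmax(dim=-1).detach()                            # prepare next-step draft tokens
total_loss.backward()
```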

Layer definitions

Files:

  • src/speculators/train/eagle3/model_definitions.py

Currently just contains model definitions for Llama-3-style draft models. Supports norm_before_residual=True or False. We attempted to keep modifications to the original Llama model code minimal.

Distributed training via FSDP

Files:

  • src/speculators/train/utils.py
  • src/speculators/train/checkpointer.py
  • src/speculators/train/trainer.py (setup_model fn)

Full support for FSDP training by launching the training script with torchrun --nnodes --nproc_per_node=N, where N is the number of GPUs. Tested with N=2, 3, 4, and 8; all work. FSDP training also enables Automatic Mixed Precision (AMP) for improved performance.

checkpointer.py contains checkpointing logic for FSDP distributed model weights (gather all weights on rank 0 before saving).

Note: the way distributed training works in general is that N copies of the script are started, all running the same code but with environment variables set that let each process know its rank. Explicit dist.barrier() calls, or implicit ones within FSDP forward/backward hooks, then force each process to wait until all of them reach the same point in the code before continuing. It is important that all ranks reach these operations, as that is what allows them to perform synchronized operations (such as gathering, reducing, etc.). However, we can also limit certain code to only one rank (rank 0), so that we only log once or save a checkpoint once, using simple if local_rank == 0 statements.
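For illustration, the rank-0 guard plus barrier pattern described above looks roughly like this (the function name and arguments are hypothetical, not the checkpointer.py API):

```python
import torch
import torch.distributed as dist

def save_on_rank0(state_dict: dict, path: str, local_rank: int) -> None:
    if local_rank == 0:
        # Only rank 0 writes the (already gathered) weights to disk.
        torch.save(state_dict, path)
    # Every rank waits here so no process races ahead while rank 0 is saving.
    dist.barrier()
```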

Logging

Files:

  • src/speculators/train/logger.py
  • scripts/train.py (logger setup calls at the start of main())
  • src/speculators/train/trainer.py and other files: usage of metric_logger and root_logger

Another implementation mostly copied from prior work I did on instructlab/training. This uses Python's standard-library logging module and extends it to support training-metric logging. We can log a nested dict of metrics anywhere in the codebase like so:

# Setup once
import logging
metric_logger = logging.getLogger("speculators.metrics")

# Log call
metric_logger.info(
    {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch},
    extra={"step": self.global_step},
)

When the user runs the training script, they can select one (or more) of tensorboard, wandb, and trackio, and the results will be logged to the respective experiment tracker(s).

There is also a root_logger which can be used for regular status logging; everything logged to either the root_logger or the metric_logger will be pretty-printed to the console.
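As a rough sketch of how metric logging can plug into the standard logging machinery (illustrative only; this is not the actual handler in logger.py):

```python
import logging

class PrintMetricHandler(logging.Handler):
    """Toy handler: a real one would forward the dict to tensorboard / wandb / trackio."""

    def emit(self, record: logging.LogRecord) -> None:
        step = getattr(record, "step", None)  # populated via the extra={"step": ...} kwarg
        if isinstance(record.msg, dict):
            print(f"step={step}: {record.msg}")

metric_logger = logging.getLogger("speculators.metrics")
metric_logger.setLevel(logging.INFO)
metric_logger.addHandler(PrintMetricHandler())
metric_logger.info({"train": {"loss": 0.42}}, extra={"step": 1})
```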

Trainer

Files:

  • src/speculators/train/trainer.py

The Trainer class is initialized with the model, data loaders, and a config and:

  • Sets up model / optimizer (loads weights and configures distributed if needed)
  • Contains the training and validation loops (train_epoch and val_epoch respectively)
  • And the overall training loop, which alternates between training, validation, and saving checkpoints (see the skeleton sketch after this list)
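Conceptually, the outer loop amounts to the skeleton below (method bodies omitted; save_checkpoint is an assumed name, while train_epoch and val_epoch are the methods named above):

```python
class TrainerSkeleton:
    """Illustrative skeleton only; the real Trainer also handles FSDP/optimizer setup, AMP, and logging."""

    def __init__(self, num_epochs: int):
        self.num_epochs = num_epochs

    def train_epoch(self, epoch: int) -> None: ...      # forward/backward, optimizer steps, metric logging
    def val_epoch(self, epoch: int) -> None: ...        # validation pass with distributed loss reduction
    def save_checkpoint(self, epoch: int) -> None: ...  # assumed name; currently a checkpoint is saved every epoch

    def run(self) -> None:
        # Alternate between training, validation, and checkpointing.
        for epoch in range(self.num_epochs):
            self.train_epoch(epoch)
            self.val_epoch(epoch)
            self.save_checkpoint(epoch)
```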

Todos:

  • Eagle3DraftModel definition with TTT steps and loss calculations
  • Patched Decoder layer definitions
  • Simple data loading from sample files
  • FlexAttention masking and implementation
  • Loss Masking
  • Training loop
    • Train data loader
    • loss.backward() + optimizer steps
    • Distributed loss reduction
    • Val data loader
    • Metric collection/reporting
    • Model checkpointing
  • Data batching
    • Collate fn
    • Batch sampler (dynamic batch size through sample packing)
    • Distributed (rank) aware sampling
  • Distributed support
  • Code relocation / merging with existing definitions (Currently just have everything under speculators/train but this will need to change) FUTURE PR
  • Verify correctness of key components (attention masking, data token alignment, etc).
  • General testing

Essential todos (as of 10/22/2025):

  • Save checkpoints to safetensors format w/ required config info
  • Implement save best or save last logic (currently saving every epoch) FUTURE PR
  • Better Verifier lm_head, embed_tokens loading (requires added loading util for specific layers #144)
  • Eagle3DraftModel.__init__ signature cleanup/better configuration
  • Config/argparsing for scripts/train.py FUTURE PR
  • Ensure flex attention impl works with torch==2.9 and torch.compile
  • Fix lint / quality / type errors and pass CI

@github-actions

github-actions bot commented Oct 3, 2025

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/19339085320/artifacts/4558708403.
They will be retained for up to 30 days.
Commit: 45bf922

@fynnsu fynnsu force-pushed the eagle3_training branch 4 times, most recently from 33b96a6 to 3d12f28 Compare October 8, 2025 21:30
@fynnsu fynnsu force-pushed the eagle3_training branch 10 times, most recently from 2df7e2c to 129adb3 Compare October 23, 2025 21:38
@fynnsu fynnsu changed the title [WIP] Eagle3 Training Implementation Eagle3 Training Oct 24, 2025
@fynnsu fynnsu marked this pull request as ready for review October 24, 2025 17:05
@fynnsu fynnsu requested a review from eldarkurtic October 24, 2025 18:53
Collaborator

@brian-dellabetta brian-dellabetta left a comment


Excellent work! In the future, consider splitting a PR like this up into separate smaller PRs that can be merged over time. This one looks like it could be split into a few -- logging, trainer, dataset class, llama-specific code.

A few comments from an outsider's perspective. Since this is all entirely new, I'm sure there will be some validation after this lands, but consider an e2e test here or in a follow-up.

Collaborator

@kylesayrs kylesayrs left a comment


Super beautiful, really great job.

The only part that makes me slightly nervous is that reimplementation of model definitions, and the fixed ModelComponents structure, as this pattern makes supporting new models harder/ more rigid. If there's any way to make this more general/ provide a good programming model, that would be nice. Otherwise, this is good for now.

@fynnsu
Collaborator Author

fynnsu commented Oct 28, 2025

The only part that makes me slightly nervous is that reimplementation of model definitions, and the fixed ModelComponents structure, as this pattern makes supporting new models harder/ more rigid. If there's any way to make this more general/ provide a good programming model, that would be nice. Otherwise, this is good for now.

Yeah that's fair. This is also my least favorite part :(

Unfortunately, we do need to make a few changes to the DecoderLayer but I did my best to minimize these and clearly mark them with comments. I'm happy to look further into methods for removing these modifications but right now it isn't super clear to me how we could do that.

As for ModelComponents, this is really just meant as a super simple temporary way to bundle up all the classes associated with a model architecture. It's possible we will need to extend or modify this in the future but it feels premature to do so before we even have a second drafter architecture that we would want to support.

Supporting other drafter architectures for spec decoding is less critical than supporting other model architectures usually is, because we don't need to match the drafter architecture to the verifier's; e.g. there's nothing stopping you from training a Llama drafter model on a Qwen verifier.

@fynnsu fynnsu force-pushed the eagle3_training branch 2 times, most recently from 2638cff to b075197 Compare October 30, 2025 21:58
- Only load files ending with `pt`
- Enforce loading on cpu


Although the object is a `BlockMask`, rename it to `attention_mask` so that the naming aligns with what is used in the transformer components (DecoderBlock and attention fn)


In the Eagle3 algorithm, the first layer needs to be modified to support a larger (2x) hidden dim, while subsequent layers behave as usual.
Previously, we used a special decoder layer class that behaved differently depending on the `layer_idx` it received. Now we instead use the
special class only for the first layer and switch to using the original `LlamaDecoderLayer` for subsequent layers.


Simplifies the code, while removing the option for a non-zero Gaussian transform mean.

Also changes the default values for the standard deviation and fixes a scaling issue in the uniform transform.


@dsikka dsikka previously approved these changes Nov 13, 2025
@fynnsu fynnsu dismissed stale reviews from dsikka and HDCharles via 45bf922 November 13, 2025 16:48
@dsikka dsikka enabled auto-merge (squash) November 13, 2025 16:58
@dsikka dsikka merged commit 8a2d483 into main Nov 13, 2025
24 checks passed
@dsikka dsikka deleted the eagle3_training branch November 13, 2025 17:48